|
|||||||||||
PREV CLASS NEXT CLASS | FRAMES NO FRAMES | ||||||||||
SUMMARY: NESTED | FIELD | CONSTR | METHOD | DETAIL: FIELD | CONSTR | METHOD |
java.lang.ObjectmscJNeuralNet.trainer.NetTrainer
Diese Klasse dient zum Trainieren eines KNN. Mit Hilfe dieser Klasse kann das Trainieren des Netzes als nebenläufiger Prozess realisiert werden. Gleichzeitig stehen nicht-nebenläufige Methoden zur Verfügung, um ein Netz bis zu einem bestimmten Fehlerwert oder einem bestimmten Lernzyklus zu trainieren.
train(INetTrainingAlgorithm, double[][], double[][], int)
:
Trainieren bis eine bestimmte Anzahl an Lernschritten absolviert wurde.
train(INetTrainingAlgorithm, double[][], double[][], double, int)
:
Trainieren bis zu einem bestimmten Fehlerwert. Es werden alle Haupt-Fehlertypen
aus NetPerformanceStatistics
unterstützt.
train(INetTrainingAlgorithm, double[][], double[][], double, int, int)
:
Trainieren bis zu einem bestimmten Fehlerwert oder bis eine bestimmte
Anzahl an Lernschritten absolviert wurde.
stop()
oder durch erreichen bestimmter Bedingungen, wie dem Unterschreiten eines Fehlerwertes
oder das Absolvieren einer bestimmten Anzahl an Lernschritten angehalten werden.
Das Training kann durch die Schnittstelle Observable
überwacht werden, denn der NetTrainer benachrichtigt nach jedem Lernschritt alle gemeldeten
Observer. Im folgenden ein einfaches Beispiel, dass das dreidimensionale Xor-Problem als
nebenläufigen Prozess trainiert:package mscJNeuralNet.examples; import observerPattern.Observable; import observerPattern.Observer; import mscJNeuralNet.connectors.INetConnector; import mscJNeuralNet.connectors.RandomSymmetryBreakingNetConnector; import mscJNeuralNet.net.Net; import mscJNeuralNet.net.PatternDoesNotMatchNetException; import mscJNeuralNet.netPerformanceStatistics.NetPerformanceReporter; import mscJNeuralNet.netPerformanceStatistics.NetPerformanceStatistics; import mscJNeuralNet.patterns.Patterns; import mscJNeuralNet.trainer.NetTrainer; import mscJNeuralNet.trainingAlgorithms.INetTrainingAlgorithm; import mscJNeuralNet.trainingAlgorithms.RProp; public class TestTrippleXorLearnConcurrent implements Observer{ public TestTrippleXorLearnConcurrent() throws PatternDoesNotMatchNetException{ // 1. Erzeugen der benötigten Klassen // MLPCreated on 05.06.2004Net
// MLP mit der gewünschten Schichtstruktur erstellen: // Eingabeschicht: 3 Neuronen // 1. Verdeckte Schicht: 3 Neuronen // Ausgabeschicht: 1 Neuron int [] layerSizesTrippleXOr = {3, 3, 1}; //Net
myNet = new Net(layerSizesTrippleXOr); Net myNet = new Net(layerSizesTrippleXOr); // BIAS wurde automatisch berücksichtigt. //INetConnector
// Diese Klasse wird zum Initialisieren der Kantengewichte benötigt INetConnector lNetConnectionAlgo = new RandomSymmetryBreakingNetConnector(); //INetTrainingAlgorithm
// Diese Klasse wird zum Trainieren eines MLP benötigt INetTrainingAlgorithm lNetTrainAlgo = new RProp(); //NetTrainer
// Diese Klasse wird das Netz als nebenläufigen Prozess trainieren. NetTrainer lTrainer = new NetTrainer(); // 2. Initialisieren des Netzes // MLP +INetConnector
// Mit der Instanz vonINetConnector
inititalisieren. // Die KlasseRandomSymmetryBreakingNetConnector
benötigt // keine eigenen Kontrollparameter und wird daher mit dem Wert null aufgerufen. lNetConnectionAlgo.connectNet(myNet, null); // Nun ist das Netz verbunden und initialisiert. //INetTrainingAlgorithm
// LernverfahrenINetTrainingAlgorithm
mit Netz verbinden. lNetTrainAlgo.setNet(myNet); // 3. Trainer konfigurieren // Dem Trainer das Lernverfahren und das MLP mitteilen lTrainer.setTrainingAlgorithm(lNetTrainAlgo); // 100 Lernschritte lang trainieren lTrainer.setTargetCycles(100); // oder bis Fehlerwert unter 0.01 lTrainer.setTargetError(0.01); // Fehlerwert soll vom Typ SSE sein lTrainer.setTargetErrorType( NetPerformanceStatistics.ERRORTYPE_averageSumOfSquaredError); // Lerndaten immer in derselben Reihenfolge präsentieren lTrainer.setUseRandomizedPatternOrder(false); // Dieses Programm als Observer des Trainers anmelden lTrainer.getObserverManager().addObserver(this); // 4. Trainieren des Netzes // TRAININGSDATEN // Trainingsdaten bereitstellen: // Trainingsdaten in ein Patterns Objekt übertragen // Komforbaler ist es, die Lerndaten in einer Textdatei // zu speichern und hier zu laden Patterns lTrippleXorPat = new Patterns(3, 1, 8); // {-1,-1,-1} lTrippleXorPat.setInputToken(0, 0, -1); lTrippleXorPat.setInputToken(1, 0, -1); lTrippleXorPat.setInputToken(2, 0, -1); // {-1} lTrippleXorPat.setOutputToken(0, 0, -1); // {-1,-1,1} lTrippleXorPat.setInputToken(0, 1, -1); lTrippleXorPat.setInputToken(1, 1, -1); lTrippleXorPat.setInputToken(2, 1, 1); // {1} lTrippleXorPat.setOutputToken(0, 1, 1); // {-1,1,-1} lTrippleXorPat.setInputToken(0, 2, -1); lTrippleXorPat.setInputToken(1, 2, 1); lTrippleXorPat.setInputToken(2, 2, -1); // {1} lTrippleXorPat.setOutputToken(0, 2, 1); // {-1,1,1} lTrippleXorPat.setInputToken(0, 3, -1); lTrippleXorPat.setInputToken(1, 3, 1); lTrippleXorPat.setInputToken(2, 3, 1); // {-1} lTrippleXorPat.setOutputToken(0, 3, -1); // {1,-1,-1} lTrippleXorPat.setInputToken(0, 4, 1); lTrippleXorPat.setInputToken(1, 4, -1); lTrippleXorPat.setInputToken(2, 4, -1); // {1} lTrippleXorPat.setOutputToken(0, 4, 1); // {1,-1,1} lTrippleXorPat.setInputToken(0, 5, 1); lTrippleXorPat.setInputToken(1, 5, -1); lTrippleXorPat.setInputToken(2, 5, 1); // {-1} lTrippleXorPat.setOutputToken(0, 5, -1); // {1,1,-1} lTrippleXorPat.setInputToken(0, 6, 1); lTrippleXorPat.setInputToken(1, 6, 1); lTrippleXorPat.setInputToken(2, 6, -1); // {-1} lTrippleXorPat.setOutputToken(0, 6, -1); // {1,1,1} lTrippleXorPat.setInputToken(0, 7, 1); lTrippleXorPat.setInputToken(1, 7, 1); lTrippleXorPat.setInputToken(2, 7, 1); // {1} lTrippleXorPat.setOutputToken(0, 7, 1); // Trainingsdatenmenge dem Trainer mitteilen lTrainer.setTrainingPatterns(lTrippleXorPat); // Training beginnen lTrainer.start(); } public static void main(String [] args) { try{ new TestTrippleXorLearnConcurrent(); } catch (PatternDoesNotMatchNetException e){ e.printStackTrace(); } } public void notify(Observable pObservable) { // Diese Methode wird vom NetTrainer nach jedem // Lernschritt aufgerufen. // Das übergebene Observable Objekt ist der NetTrainer selbst if (pObservable instanceof NetTrainer){ NetTrainer lTrainer = (NetTrainer) pObservable; // Prüfen, ob das Training vorbei ist if (lTrainer.hasFinished()){ System.out.println("Training fertig."); System.out.println( NetPerformanceReporter.getNetPerformance( lTrainer.getLastCalcualtedNetStatistics(), lTrainer.getTrainingAlgorithm().getCycle() ) ); System.exit(0); } else{ // Sonst alle 10 Lernschritte Status ausgeben if (lTrainer.getTrainingAlgorithm().getCycle() % 10 == 0){ System.out.println("Lernschritt: "+ lTrainer.getTrainingAlgorithm().getCycle()); System.out.println("Aktueller Netzfehler: "+ lTrainer.getTargetErrorTypeString()+" "+ lTrainer.getLastCalculatedError()); } } } } }
Net
Constructor Summary | |
NetTrainer()
|
Method Summary | |
NetPerformanceStatistics |
getLastCalcualtedNetStatistics()
Liefert die akteullen Fehlerwerte des Netzes. |
double |
getLastCalculatedError()
Liefert den Fehlerwert aus dem letzten Lernschritt. |
observerPattern.ObserverManager |
getObserverManager()
|
int |
getTargetCycles()
Liefert die Anzahl der Lernschritte, die der Trainer absolvieren soll. |
double |
getTargetError()
Liefert den Fehlerwert, dessen Unterschreiten das Training beendet. |
int |
getTargetErrorType()
Liefert den Typ des Fehlerwertes aus setTargetError(double) .
|
java.lang.String |
getTargetErrorTypeString()
Liefert die Stringrepräsentation des aktuellen Fehlertyps zurück. |
INetTrainingAlgorithm |
getTrainingAlgorithm()
Liefert das aktuell benutzte Lernverfahren. |
Patterns |
getTrainingPatterns()
Liefert die lernenden Lerndatensätze zurück. |
boolean |
hasFinished()
|
boolean |
isRunning()
Prüft, ob der nebenläufige Prozess dieses Trainers aktiv ist. |
boolean |
isUsingRandomizedPatternOrder()
Testet, ob der Trainer zum Trainieren die Lerndatensätze in gegebener oder in zufälliger Reihenfolge in jedem Lernschritt benutzt. |
void |
run()
|
void |
setTargetCycles(int pTargetCycles)
Legt fest, wieviele Lernschritte der Trainer absolvieren soll. |
void |
setTargetError(double pTargetError)
Legt fest, bei welchem Fehlerwert der Trainer das Training als erfolgreich beenden soll. |
void |
setTargetErrorType(int pTargetErrorType)
Typ des Fehlers, dessen Unterschreitung das Training beendet setTargetError(double) .
|
void |
setTrainingAlgorithm(INetTrainingAlgorithm pTrainingAlgorithm)
Legt das Lernverfahren für das Training fest. |
void |
setTrainingPatterns(Patterns pTrainingPatterns)
Legt die Lerndatenmenge fest. |
void |
setUseRandomizedPatternOrder(boolean pUseRandomizedPatternOrder)
Legt fest, ob beim Training die Lerndatensätze immer in der gleichen Reihenfolge gelernt werden sollen, oder ob nach jedem Lernschritt die Reihenfolge zufällig ermittelt werden soll. |
void |
start()
Startet den nebenläufigen Prozess zum Trainiern des Netzes. |
void |
stop()
Hält den nebenläufigen Prozess zum Trainieren des Netzes an. |
static void |
train(INetTrainingAlgorithm pTrainingAlgorithm,
double[][] pInputPatterns,
double[][] pOutputPatterns,
double pTargetError,
int pErrorType)
Trainieren bis zu einem bestimmten Fehlerwert. |
static void |
train(INetTrainingAlgorithm pTrainingAlgorithm,
double[][] pInputPatterns,
double[][] pOutputPatterns,
double pTargetError,
int pErrorType,
int pCycles)
Trainieren bis zu einem bestimmten Fehlerwert oder bis eine bestimmte Anzahl an Lernschritten absolviert wurde. |
static void |
train(INetTrainingAlgorithm pTrainingAlgorithm,
double[][] pInputPatterns,
double[][] pOutputPatterns,
int pCycles)
Trainieren bis die gegebene Anzahl an Lernschritten absolviert wurde. |
Methods inherited from class java.lang.Object |
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
Constructor Detail |
public NetTrainer()
Method Detail |
public void setTrainingAlgorithm(INetTrainingAlgorithm pTrainingAlgorithm) throws PatternDoesNotMatchNetException
pTrainingAlgorithm
- Das zu benutzende Lernverfahren und das enthaltene Netz.
PatternDoesNotMatchNetException
- Falls eines der gegebenen Lerndatensätze
nicht der Eingabe-/ Ausgabeschichtgröße des Netzes entspricht.public INetTrainingAlgorithm getTrainingAlgorithm()
public void setTrainingPatterns(Patterns pTrainingPatterns) throws PatternDoesNotMatchNetException
pTrainingPatterns
- Die zu lernende Lerndatensätze.
PatternDoesNotMatchNetException
- Falls eines der gegebenen Lerndatensätze
nicht der Eingabe-/ Ausgabeschichtgröße des Netzes entspricht.public Patterns getTrainingPatterns()
public void setTargetCycles(int pTargetCycles)
setTargetError(double)
definierte Fehlerwert
unterschritten wurde.
pTargetCycles
- Anzahl der Lernschritte, die der Trainer absolvieren soll.setTargetError(double)
public int getTargetCycles()
public void setTargetError(double pTargetError)
setTargetCycles(int)
angegebene Anzahl Lernschritte
absolviert wurden
oder der definierte Fehlerwert unterschritten wurde.NetPerformanceStatistics
sein.
pTargetError
- Fehlerwert, dessen Unterschreiten das Training beendet.setTargetCycles(int)
,
setTargetErrorType(int)
public double getTargetError()
public void setTargetErrorType(int pTargetErrorType)
setTargetError(double)
.
Als Fehlertyp kann jeder beliebige Hauptfehlertyp
aus der Klasse NetPerformanceStatistics
benutzt werden.
pTargetErrorType
- Fehlertyp (ERRORTYPE-Konstante) aus der Klasse NetPerformanceStatistics
setTargetError(double)
,
getTargetErrorTypeString()
public int getTargetErrorType()
setTargetError(double)
.
Der zurückgelieferte Wert entspricht einer Konstante für den
Fehlertyp (ERRORTYPE-Konstante) aus der Klasse NetPerformanceStatistics
NetPerformanceStatistics
getTargetErrorTypeString()
public java.lang.String getTargetErrorTypeString()
public boolean isUsingRandomizedPatternOrder()
setUseRandomizedPatternOrder(boolean)
public void setUseRandomizedPatternOrder(boolean pUseRandomizedPatternOrder)
pUseRandomizedPatternOrder
- true = zufällige Reihenfolge benutzen,
false = immer die gleiche Reihenfolge benutzen.isUsingRandomizedPatternOrder()
,
Patterns.getRandomizedPatternsOrder()
public double getLastCalculatedError()
getTargetErrorType()
ermittelt werden.
getTargetErrorType()
,
getLastCalcualtedNetStatistics()
public NetPerformanceStatistics getLastCalcualtedNetStatistics()
public boolean hasFinished()
public void start()
setTrainingAlgorithm(INetTrainingAlgorithm)
ein
Lernverfahren mit Netz und
über setTrainingPatterns(Patterns)
Lerndatensätze zugeteilt werden.
Zusätzlich sollte und Aufruf von setTargetCycles(int)
,
setTargetError(double)
, setTargetErrorType(int)
das gewünschte Lernziel festgelegt werden.
public void stop()
public boolean isRunning()
public void run()
run
in interface java.lang.Runnable
public observerPattern.ObserverManager getObserverManager()
getObserverManager
in interface observerPattern.Observable
public static void train(INetTrainingAlgorithm pTrainingAlgorithm, double[][] pInputPatterns, double[][] pOutputPatterns, int pCycles) throws PatternDoesNotMatchNetException
pTrainingAlgorithm
- Lernverfahren, das beim Training bernutzt werden soll.
Das Lernverfahren enthält auch die Referenz zum Netz, das trainiert werden soll.pInputPatterns
- Eingabemuster der Lerndatensätze.pOutputPatterns
- Ausgabemuster der Lerndatensätze (Soll-Werte).pCycles
- Anzahl der Lernschritte, die absolviert werden sollen.
PatternDoesNotMatchNetException
- Falls eines der gegebenen Lerndatensätze
nicht der Eingabe-/ Ausgabeschichtgröße des Netzes entspricht.public static void train(INetTrainingAlgorithm pTrainingAlgorithm, double[][] pInputPatterns, double[][] pOutputPatterns, double pTargetError, int pErrorType) throws PatternDoesNotMatchNetException
NetPerformanceStatistics
unterstützt.
pTrainingAlgorithm
- Lernverfahren, das beim Training bernutzt werden soll.
Das Lernverfahren enthält auch die Referenz zum Netz, das trainiert werden soll.pInputPatterns
- Eingabemuster der Lerndatensätze.pOutputPatterns
- Ausgabemuster der Lerndatensätze (Soll-Werte).pTargetError
- Zielwert des Fehlers, bei dem das Training gestoppt werden soll.pErrorType
- Fehlertyp (ERRORTYPE-Konstante) aus der Klasse NetPerformanceStatistics
PatternDoesNotMatchNetException
- Falls eines der gegebenen Lerndatensätze
nicht der Eingabe-/ Ausgabeschichtgröße des Netzes entspricht.NetPerformanceStatistics
public static void train(INetTrainingAlgorithm pTrainingAlgorithm, double[][] pInputPatterns, double[][] pOutputPatterns, double pTargetError, int pErrorType, int pCycles) throws PatternDoesNotMatchNetException
NetPerformanceStatistics
unterstützt.
pTrainingAlgorithm
- Lernverfahren, das beim Training bernutzt werden soll.
Das Lernverfahren enthält auch die Referenz zum Netz, das trainiert werden soll.pInputPatterns
- Eingabemuster der Lerndatensätze.pOutputPatterns
- Ausgabemuster der Lerndatensätze (Soll-Werte).pTargetError
- Zielwert des Fehlers, bei dem das Training gestoppt werden sollpErrorType
- Fehlertyp (ERRORTYPE-Konstante) aus der Klasse NetPerformanceStatistics
pCycles
- Anzahl der Lernschritte, die absolviert werden sollen.
PatternDoesNotMatchNetException
- Falls eines der gegebenen Lerndatensätze
nicht der Eingabe-/ Ausgabeschichtgröße des Netzes entspricht.NetPerformanceStatistics
|
|||||||||||
PREV CLASS NEXT CLASS | FRAMES NO FRAMES | ||||||||||
SUMMARY: NESTED | FIELD | CONSTR | METHOD | DETAIL: FIELD | CONSTR | METHOD |